Sheet 8

Anna Sommani, Steffen Albert


In [189]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

In [199]:
#load data from file
import scipy.io as sio
data=sio.loadmat('Tut8_file1.mat')
state = data['state'].squeeze()
reward = data['reward'].squeeze()
response = data['response'].squeeze()
T = len(state)

In [200]:
#do meth stuff
V11 = np.zeros(T+1)
V12 = np.zeros(T+1)
V21 = np.zeros(T+1)
V22 = np.zeros(T+1)

for t in range(1,T+1):
    V11[t], V12[t], V21[t], V22[t] = V11[t-1], V12[t-1], V21[t-1], V22[t-1]
    if state[t-1] == 1 and response[t-1] == 1:
        V11[t]=V11[t-1]+0.1*(reward[t-1]-V11[t-1])
    if state[t-1] == 1 and response[t-1] == 2:
        V12[t]=V12[t-1]+0.1*(reward[t-1]-V12[t-1])
    if state[t-1] == 2 and response[t-1] == 1:
        V21[t]=V21[t-1]+0.1*(reward[t-1]-V21[t-1])
    if state[t-1] == 2 and response[t-1] == 2:
        V22[t]=V22[t-1]+0.1*(reward[t-1]-V22[t-1])

In [206]:
plt.plot(V11, label='$V_{11}$')
plt.plot(V12, label='$V_{12}$')
plt.legend()


Out[206]:
<matplotlib.legend.Legend at 0x119d90780>

In [202]:
plt.plot(V21, label='$V_{21}$')
plt.plot(V22, label='$V_{22}$')
plt.plot((68, 68), (0, 1), 'r:')
plt.legend()


Out[202]:
<matplotlib.legend.Legend at 0x118660a90>

In [203]:
for t in range(T):
    if (state[t] != response[t] and reward[t] == 1) or (state[t] == response[t] and reward[t] == 0):
        print(f't = {t+1:3d} for state {state[t]} and response {response[t]} and reward {reward[t]}')
    if state[t] == 1 and response[t] == 2 and reward[t] == 1:
        print(f't = {t+1:3d} for state {state[t]} and response {response[t]} and reward {reward[t]}')


t =  68 for state 2 and response 2 and reward 0
t =  71 for state 2 and response 2 and reward 0
t =  72 for state 2 and response 2 and reward 0
t =  74 for state 2 and response 2 and reward 0
t =  76 for state 2 and response 2 and reward 0
t =  79 for state 2 and response 2 and reward 0
t =  80 for state 2 and response 2 and reward 0
t =  83 for state 2 and response 1 and reward 1
t =  84 for state 2 and response 2 and reward 0
t =  86 for state 2 and response 2 and reward 0
t =  88 for state 2 and response 2 and reward 0
t =  90 for state 2 and response 2 and reward 0
t =  92 for state 2 and response 1 and reward 1
t =  94 for state 2 and response 1 and reward 1
t =  97 for state 2 and response 1 and reward 1
t =  99 for state 2 and response 2 and reward 0
t = 100 for state 2 and response 2 and reward 0
t = 102 for state 2 and response 2 and reward 0
t = 105 for state 2 and response 2 and reward 0
t = 107 for state 2 and response 2 and reward 0
t = 109 for state 2 and response 1 and reward 1
t = 111 for state 2 and response 1 and reward 1
t = 112 for state 2 and response 1 and reward 1
t = 114 for state 2 and response 1 and reward 1
t = 117 for state 2 and response 1 and reward 1
t = 119 for state 2 and response 1 and reward 1
t = 120 for state 2 and response 2 and reward 0
t = 123 for state 2 and response 1 and reward 1
t = 125 for state 2 and response 1 and reward 1

So at first, the reward was given if the reponse was the same as the state, for the state 1, it stayed that way for the whole time and the rat learned it. For the state 2, the meaning of the light was switched in experiment 68, so that the rat had to press the other lever to get a reward, which it learned also.


In [278]:
def LL(state, reward, response ,a,b):
    T = len(state)
    
    V11 = np.zeros(T+1)
    V12 = np.zeros(T+1)
    V21 = np.zeros(T+1)
    V22 = np.zeros(T+1)
    
    LL = - np.log(2)

    for t in range(1,T+1):
        V11[t], V12[t], V21[t], V22[t] = V11[t-1], V12[t-1], V21[t-1], V22[t-1]
        if state[t-1] == 1 and response[t-1] == 1:
            V11[t]=V11[t-1]+a*(reward[t-1]-V11[t-1])
            LL += b * V11[t]
        if state[t-1] == 1 and response[t-1] == 2:
            V12[t]=V12[t-1]+a*(reward[t-1]-V12[t-1])
            LL += b * V12[t]
        if state[t-1] == 2 and response[t-1] == 1:
            V21[t]=V21[t-1]+a*(reward[t-1]-V21[t-1])
            LL += b * V21[t]
        if state[t-1] == 2 and response[t-1] == 2:
            V22[t]=V22[t-1]+a*(reward[t-1]-V22[t-1])
            LL += b * V22[t]
            
        if state[t-1] == 1:
            LL += -np.log(np.exp(b*V11[t])+np.exp(b*V12[t]))
        else:
            LL += -np.log(np.exp(b*V21[t])+np.exp(b*V22[t]))
    
    return LL

alpha = np.linspace(0,1,201)
beta = np.linspace(0,3,601)
data = np.array([[LL(state, reward, response ,a,b) for b in beta] for a in alpha])
print(data)


[[ -87.33654475  -87.33654475  -87.33654475 ...,  -87.33654475
   -87.33654475  -87.33654475]
 [ -87.33654475  -87.3208955   -87.30525445 ...,  -79.43070029
   -79.41988397  -79.40907552]
 [ -87.33654475  -87.30710289  -87.27768859 ...,  -74.54232665  -74.528658
   -74.51501351]
 ..., 
 [ -87.33654475  -87.20433413  -87.07285375 ..., -107.50454602
  -107.63628971 -107.7681662 ]
 [ -87.33654475  -87.20437244  -87.07293087 ..., -107.58198577
  -107.71388682 -107.84592065]
 [ -87.33654475  -87.20441038  -87.07300724 ..., -107.65879811
  -107.79085507 -107.92304476]]

In [282]:
#find maximum
maximum = np.where(data == np.amax(data))
maxalpha = alpha[maximum[0][0]]
maxbeta = beta[maximum[1][0]]
print(f'The maximum likelihood is at alpha = {maxalpha:.2f} and ' +
f'beta = {maxbeta:.2f} with a log likelihood of {np.amax(data):.2f}')

plt.rcParams['figure.figsize'] = [18.0, 6.0]
import matplotlib.cm as cm
plt.imshow(data[::-1], extent=(np.min(beta), np.max(beta), np.min(alpha), np.max(alpha)), cmap=cm.gist_rainbow)
plt.xlabel('beta')
plt.ylabel('alpha')
plt.colorbar()

#plot maximum
plt.scatter(x=[maxbeta+0.025], y=[maxalpha+0.025], c='w', s=40)

#plt.colorbar()
plt.rcParams['figure.figsize'] = [8.0, 6.0]


The maximum likelihood is at alpha = 0.08 and beta = 2.19 with a log likelihood of -64.56

In [319]:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter


fig = plt.figure()
ax = fig.gca(projection='3d')

# Make data.
X = beta
Y = alpha
X, Y = np.meshgrid(X, Y)
Z = data

# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap=cm.gist_rainbow,
                       linewidth=0, antialiased=True)

# Customize the z axis.
ax.set_zlim(np.amin(data), np.amax(data))
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()



In [ ]: